{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "29c26766",
   "metadata": {},
   "source": [
    "# 张量的保存和加载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d709ab87",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "588159e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.8608, 0.6997, 0.4133, 0.6113, 0.5393, 0.8223])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Draw six samples from the uniform distribution on [0, 1) into a 1-D tensor.\n",
    "a = torch.rand((6,))\n",
    "# A bare name on the last line makes the notebook display the tensor.\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "43aa15a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save does not create missing parent directories, so make sure model/ exists first.\n",
    "os.makedirs(\"model\", exist_ok=True)\n",
    "\n",
    "# Serialize the tensor `a` to the file model/tensor_a.\n",
    "torch.save(a, \"model/tensor_a\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3589ec6b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.8608, 0.6997, 0.4133, 0.6113, 0.5393, 0.8223])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the tensor back from disk; the returned value is shown as the cell output.\n",
    "torch.load(f=\"model/tensor_a\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "05bf71bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([0.6443, 0.6780, 0.9844, 0.3475, 0.3763, 0.9680]),\n",
       " tensor([0.0351, 0.3652, 0.9474, 0.5658, 0.5001, 0.7580]),\n",
       " tensor([0.5543, 0.2713, 0.3125, 0.0378, 0.0676, 0.2208])]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Create three random tensors of length 6, drawn in the order a, b, c.\n",
    "a, b, c = (torch.rand(6) for _ in range(3))\n",
    "# Show the three tensors together as a list.\n",
    "[a, b, c]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6d5eb6ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A Python list of tensors can be saved in a single call.\n",
    "torch.save(obj=[a, b, c], f=\"model/tensor_abc\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5bac7f09",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([0.6443, 0.6780, 0.9844, 0.3475, 0.3763, 0.9680]),\n",
       " tensor([0.0351, 0.3652, 0.9474, 0.5658, 0.5001, 0.7580]),\n",
       " tensor([0.5543, 0.2713, 0.3125, 0.0378, 0.0676, 0.2208])]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loading returns the saved list of tensors, displayed as the cell output.\n",
    "torch.load(f=\"model/tensor_abc\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1391dbc8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': tensor([0.6443, 0.6780, 0.9844, 0.3475, 0.3763, 0.9680]),\n",
       " 'b': tensor([0.0351, 0.3652, 0.9474, 0.5658, 0.5001, 0.7580]),\n",
       " 'c': tensor([0.5543, 0.2713, 0.3125, 0.0378, 0.0676, 0.2208])}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map each tensor to a string key; entries keep the insertion order a, b, c.\n",
    "tensor_dict = dict(a=a, b=b, c=c)\n",
    "tensor_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4a87f7ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A dictionary of tensors is saved the same way as a single tensor.\n",
    "torch.save(obj=tensor_dict, f=\"model/tensor_dict_abc\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c374d5e6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': tensor([0.6443, 0.6780, 0.9844, 0.3475, 0.3763, 0.9680]),\n",
       " 'b': tensor([0.0351, 0.3652, 0.9474, 0.5658, 0.5001, 0.7580]),\n",
       " 'c': tensor([0.5543, 0.2713, 0.3125, 0.0378, 0.0676, 0.2208])}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loading returns the saved dictionary, displayed as the cell output.\n",
    "torch.load(f=\"model/tensor_dict_abc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5c3641a",
   "metadata": {},
   "source": [
    "# 模型的保存与加载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "592d73a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "import torch.nn as nn \n",
    "import torch.optim as optim\n",
    "\n",
    "# Multi-layer perceptron built on nn.Module\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Fully connected network with two ReLU-activated hidden layers.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of features in each input sample.\n",
    "        hidden_size: width shared by both hidden layers.\n",
    "        num_classes: number of output values produced per sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, num_classes):\n",
    "        super().__init__()\n",
    "        # Layers are registered in the order fc1, relu, fc2, fc3; this order\n",
    "        # determines both the state_dict key order and the order in which\n",
    "        # the random initial weights are drawn.\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.fc3 = nn.Linear(hidden_size, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run `x` through fc1 -> ReLU -> fc2 -> ReLU -> fc3 and return the result.\"\"\"\n",
    "        hidden = self.relu(self.fc1(x))\n",
    "        hidden = self.relu(self.fc2(hidden))\n",
    "        # No activation after the last layer: the raw fc3 output is returned.\n",
    "        return self.fc3(hidden)\n",
    "    \n",
    "# Network dimensions: input feature count, hidden layer width, number of output classes\n",
    "input_size, hidden_size, num_classes = 28 * 28, 512, 10\n",
    "\n",
    "# Instantiate the MLP with those dimensions\n",
    "model = MLP(input_size=input_size, hidden_size=hidden_size, num_classes=num_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c1fbe9d",
   "metadata": {},
   "source": [
    "### 方式1：仅保存模型参数（state_dict）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "851bcb3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save only the model's parameters (its state_dict), not the model object itself\n",
    "torch.save(obj=model.state_dict(), f=\"model/mlp_state_dict.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d555cb3e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build a fresh MLP instance to receive the stored parameters\n",
    "model_load = MLP(input_size, hidden_size, num_classes)\n",
    "\n",
    "# Read the saved parameter dictionary from disk\n",
    "mlp_state_dict = torch.load(f=\"model/mlp_state_dict.pth\")\n",
    "\n",
    "# Copy the stored parameters into the new instance; the returned key-matching report is displayed\n",
    "model_load.load_state_dict(state_dict=mlp_state_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff83264f",
   "metadata": {},
   "source": [
    "### 方式2：保存整个模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "f10908c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the whole model object rather than just its parameters\n",
    "torch.save(obj=model, f=\"model/mlp_model.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8de54658",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the whole pickled model object.\n",
    "# weights_only=False is passed explicitly because PyTorch >= 2.6 defaults to\n",
    "# weights_only=True, which refuses to unpickle arbitrary classes such as MLP.\n",
    "# NOTE: full unpickling can execute arbitrary code, so only load files you created yourself.\n",
    "# The MLP class must be defined in the session before this call.\n",
    "mlp_load = torch.load(\"model/mlp_model.pth\", weights_only=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ef9f302",
   "metadata": {},
   "source": [
    "### 方式3：checkpoint（保存训练状态以便恢复训练）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05ddf837",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training state to checkpoint. These are placeholder values so the cell runs on its own;\n",
    "# in a real training loop they come from the loop itself.\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01)\n",
    "epoch = 0\n",
    "loss = 0.0\n",
    "PATH = \"model/mlp_checkpoint.pth\"\n",
    "\n",
    "# Save everything needed to resume training in a single dictionary\n",
    "torch.save({\n",
    "            'epoch': epoch,\n",
    "            'model_state_dict': model.state_dict(),\n",
    "            'optimizer_state_dict': optimizer.state_dict(),\n",
    "            'loss': loss,\n",
    "            # add any further entries needed to resume training here\n",
    "            }, PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "098aa0b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the model and optimizer with the same structure that was saved\n",
    "model = MLP(input_size, hidden_size, num_classes)\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01)\n",
    "\n",
    "# Restore the saved state from the checkpoint dictionary\n",
    "checkpoint = torch.load(PATH)\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "epoch = checkpoint['epoch']\n",
    "loss = checkpoint['loss']\n",
    "\n",
    "# Switch to evaluation mode for inference; call model.train() instead to resume training\n",
    "model.eval()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
